import fitz  # PyMuPDF
import os
import numpy as np
from openai import OpenAI

# Set the API key environment variable (placeholder -- supply your own key; avoid hard-coding real secrets in source)
os.environ["OPENAI_API_KEY"] = "your-api-key-here"
# Initialize the OpenAI client against the SiliconFlow endpoint
client = OpenAI(
    base_url="https://api.siliconflow.cn/v1",
    api_key=os.getenv("OPENAI_API_KEY")  # Read the key back from the environment
)

def extract_text_from_pdf(pdf_path):
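    """Extract text from every page of a PDF and return it as a single string."""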
    mypdf = fitz.open(pdf_path)
    all_text = ""
    for page_num in range(mypdf.page_count):
        page = mypdf[page_num]  # Get the page
        text = page.get_text("text")  # Extract text from the page
        all_text += text  # Append the extracted text to the all_text string
    mypdf.close()  # Close the document handle
    return all_text  # Return the extracted text


def chunk_text(text, n=1000, overlap=200):
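    """Split text into chunks of n characters, with `overlap` characters shared between consecutive chunks."""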
    chunks = []
    for i in range(0, len(text), n - overlap):
        chunks.append(text[i:i + n])
    return chunks


class SimpleVectorStore:
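    """A minimal in-memory vector store; search uses cosine similarity."""
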
    def __init__(self):
        self.vectors = []
        self.texts = []
        self.metadata = []

    def add_item(self, text, embedding, metadata=None):
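        """Add a text chunk, its embedding, and optional metadata to the store."""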
        self.vectors.append(np.array(embedding))  # Convert embedding to numpy array and add to vectors list
        self.texts.append(text)  # Add the original text to texts list
        self.metadata.append(metadata or {})  # Add metadata to metadata list, use empty dict if None

    def similarity_search(self, query_embedding, k=5):
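        """Return the top-k stored items ranked by cosine similarity to the query embedding."""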
        if not self.vectors:
            return []
        query_vector = np.array(query_embedding)
        similarities = []
        for i, vector in enumerate(self.vectors):
            similarity = np.dot(query_vector, vector) / (np.linalg.norm(query_vector) * np.linalg.norm(vector))
            similarities.append((i, similarity))
        similarities.sort(key=lambda x: x[1], reverse=True)
        results = []
        for i in range(min(k, len(similarities))):
            idx, score = similarities[i]
            results.append({
                "text": self.texts[idx],  # Add the text corresponding to the index
                "metadata": self.metadata[idx],  # Add the metadata corresponding to the index
                "similarity": score  # Add the similarity score
            })
        return results  # Return the list of top k results


def create_embeddings(text, model="BAAI/bge-m3"):
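    """Create embeddings with the configured embeddings API.

    Accepts a single string or a list of strings and returns a single embedding or a
    list of embeddings accordingly. Note: very large lists may exceed the provider's
    per-request batch limit, in which case the input would need to be batched.
    """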
    input_text = text if isinstance(text, list) else [text]
    response = client.embeddings.create(
        model=model,
        input=input_text
    )
    if isinstance(text, str):
        return response.data[0].embedding
    return [item.embedding for item in response.data]


def process_document(pdf_path, chunk_size=1000, chunk_overlap=200):
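    """Extract, chunk, and embed a PDF, returning a populated SimpleVectorStore."""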
    print("Extracting text from PDF...")
    extracted_text = extract_text_from_pdf(pdf_path)
    print("Chunking text...")
    chunks = chunk_text(extracted_text, chunk_size, chunk_overlap)
    print(f"Created {len(chunks)} text chunks")
    print("Creating embeddings for chunks...")
    chunk_embeddings = create_embeddings(chunks)
    store = SimpleVectorStore()
    for i, (chunk, embedding) in enumerate(zip(chunks, chunk_embeddings)):
        store.add_item(
            text=chunk,
            embedding=embedding,
            metadata={"index": i, "source": pdf_path}
        )
    print(f"Added {len(chunks)} chunks to the vector store")
    return store


def compress_chunk(chunk, query, compression_type="selective", model="Qwen/Qwen2-1.5B-Instruct"):
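    """Compress a retrieved chunk so that only query-relevant content remains.

    compression_type selects the prompt: "selective" keeps relevant sentences verbatim,
    "summary" writes a query-focused summary, and any other value extracts exact
    relevant sentences as direct quotes. Returns (compressed_chunk, compression_ratio),
    where the ratio is the percent reduction in character count.
    """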
    if compression_type == "selective":
        system_prompt = """You are an expert at information filtering. 
        Your task is to analyze a document chunk and extract ONLY the sentences or paragraphs that are directly 
        relevant to the user's query. Remove all irrelevant content.
        Your output should:
        1. ONLY include text that helps answer the query
        2. Preserve the exact wording of relevant sentences (do not paraphrase)
        3. Maintain the original order of the text
        4. Include ALL relevant content, even if it seems redundant
        5. EXCLUDE any text that isn't relevant to the query
        Format your response as plain text with no additional comments."""
    elif compression_type == "summary":
        system_prompt = """You are an expert at summarization. 
        Your task is to create a concise summary of the provided chunk that focuses ONLY on 
        information relevant to the user's query.
        Your output should:
        1. Be brief but comprehensive regarding query-relevant information
        2. Focus exclusively on information related to the query
        3. Omit irrelevant details
        4. Be written in a neutral, factual tone
        Format your response as plain text with no additional comments."""
    else:
        system_prompt = """You are an expert at information extraction.
        Your task is to extract ONLY the exact sentences from the document chunk that contain information relevant 
        to answering the user's query.
        Your output should:
        1. Include ONLY direct quotes of relevant sentences from the original text
        2. Preserve the original wording (do not modify the text)
        3. Include ONLY sentences that directly relate to the query
        4. Separate extracted sentences with newlines
        5. Do not add any commentary or additional text
        Format your response as plain text with no additional comments."""
    user_prompt = f"""
        Query: {query}
        Document Chunk:
        {chunk}
        Extract only the content relevant to answering this query.
    """
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ],
        temperature=0
    )
    compressed_chunk = response.choices[0].message.content.strip()
    original_length = len(chunk)
    compressed_length = len(compressed_chunk)
    compression_ratio = (original_length - compressed_length) / original_length * 100
    return compressed_chunk, compression_ratio


def batch_compress_chunks(chunks, query, compression_type="selective", model="Qwen/Qwen2-1.5B-Instruct"):
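    """Compress each chunk against the same query; returns a list of (compressed_chunk, compression_ratio) pairs."""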
    print(f"Compressing {len(chunks)} chunks...")  # Print the number of chunks to be compressed
    results = []  # Initialize an empty list to store the results
    total_original_length = 0
    total_compressed_length = 0
    for i, chunk in enumerate(chunks):
        print(f"Compressing chunk {i + 1}/{len(chunks)}...")
        compressed_chunk, compression_ratio = compress_chunk(chunk, query, compression_type, model)
        results.append((compressed_chunk, compression_ratio))  # Append the result to the results list
        total_original_length += len(chunk)  # Add the length of the original chunk to the total original length
        total_compressed_length += len(compressed_chunk)  # Add the length of the compressed chunk
    overall_ratio = (total_original_length - total_compressed_length) / total_original_length * 100
    print(f"Overall compression ratio: {overall_ratio:.2f}%")  # Print the overall compression ratio
    return results


def generate_response(query, context, model="Qwen/Qwen2-1.5B-Instruct"):
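    """Generate an answer to the query using only the provided context."""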
    system_prompt = """You are a helpful AI assistant. Answer the user's question based only on the provided context.
    If you cannot find the answer in the context, state that you don't have enough information."""
    user_prompt = f"""
        Context:
        {context}
        Question: {query}
        Please provide a comprehensive answer based only on the context above.
    """
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ],
        temperature=0
    )
    return response.choices[0].message.content


def rag_with_compression(pdf_path, query, k=10, compression_type="selective", model="Qwen/Qwen2-1.5B-Instruct"):
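    """RAG pipeline with contextual compression: retrieve top-k chunks, compress them, then answer."""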
    print("\n=== RAG WITH CONTEXTUAL COMPRESSION ===")
    print(f"Query: {query}")
    print(f"Compression type: {compression_type}")
    vector_store = process_document(pdf_path)
    query_embedding = create_embeddings(query)
    print(f"Retrieving top {k} chunks...")
    results = vector_store.similarity_search(query_embedding, k=k)
    retrieved_chunks = [result["text"] for result in results]
    compressed_results = batch_compress_chunks(retrieved_chunks, query, compression_type, model)
    compressed_chunks = [result[0] for result in compressed_results]
    compression_ratios = [result[1] for result in compressed_results]
    # Drop chunks that were compressed to empty strings; fall back to the originals if none survive
    filtered_chunks = [(chunk, ratio) for chunk, ratio in zip(compressed_chunks, compression_ratios) if chunk.strip()]
    if not filtered_chunks:
        print("Warning: All chunks were compressed to empty strings. Using original chunks.")
        filtered_chunks = [(chunk, 0.0) for chunk in retrieved_chunks]
    # Unpack in both cases so the context is never built from the discarded empty strings
    compressed_chunks, compression_ratios = zip(*filtered_chunks)
    context = "\n\n---\n\n".join(compressed_chunks)
    print("Generating response based on compressed chunks...")
    response = generate_response(query, context, model)
    result = {
        "query": query,
        "original_chunks": retrieved_chunks,
        "compressed_chunks": compressed_chunks,
        "compression_ratios": compression_ratios,
        "context_length_reduction": f"{sum(compression_ratios) / len(compression_ratios):.2f}%",
        "response": response
    }
    print("\n=== RESPONSE ===")
    print(response)
    return result


def standard_rag(pdf_path, query, k=10, model="Qwen/Qwen2-1.5B-Instruct"):
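    """Baseline RAG pipeline: retrieve top-k chunks and answer without compression."""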
    print("\n=== STANDARD RAG ===")
    print(f"Query: {query}")
    vector_store = process_document(pdf_path)
    query_embedding = create_embeddings(query)
    print(f"Retrieving top {k} chunks...")
    results = vector_store.similarity_search(query_embedding, k=k)
    retrieved_chunks = [result["text"] for result in results]
    context = "\n\n---\n\n".join(retrieved_chunks)
    print("Generating response...")
    response = generate_response(query, context, model)
    result = {
        "query": query,
        "chunks": retrieved_chunks,
        "response": response
    }
    print("\n=== RESPONSE ===")
    print(response)
    return result


def evaluate_responses(query, responses, reference_answer):
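    """Use an LLM judge to rank the candidate responses against the reference answer."""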
    system_prompt = """You are an objective evaluator of RAG responses. Compare different responses to the same query
    and determine which is most accurate, comprehensive, and relevant to the query."""
    user_prompt = f"""
    Query: {query}
    Reference Answer: {reference_answer}
    """
    for method, response in responses.items():
        user_prompt += f"\n{method.capitalize()} Response:\n{response}\n"
    user_prompt += """
    Please evaluate these responses based on:
    1. Factual accuracy compared to the reference
    2. Comprehensiveness - how completely they answer the query
    3. Conciseness - whether they avoid irrelevant information
    4. Overall quality
    Rank the responses from best to worst with detailed explanations.
    """
    evaluation_response = client.chat.completions.create(
        model="Qwen/Qwen2-1.5B-Instruct",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ],
        temperature=0
    )
    return evaluation_response.choices[0].message.content


def evaluate_compression(pdf_path, query, reference_answer=None,
                         compression_types=("selective", "summary", "extraction")):
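    """Run standard RAG and each compression variant on the same query, then evaluate and collect metrics."""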
    print("\n=== EVALUATING CONTEXTUAL COMPRESSION ===")
    print(f"Query: {query}")
    standard_result = standard_rag(pdf_path, query)
    compression_results = {}
    for comp_type in compression_types:
        print(f"\nTesting {comp_type} compression...")
        compression_results[comp_type] = rag_with_compression(pdf_path, query, compression_type=comp_type)
    responses = {
        "standard": standard_result["response"]
    }
    for comp_type in compression_types:
        responses[comp_type] = compression_results[comp_type]["response"]
    if reference_answer:
        evaluation = evaluate_responses(query, responses, reference_answer)
        print("\n=== EVALUATION RESULTS ===")
        print(evaluation)
    else:
        evaluation = "No reference answer provided for evaluation."
    metrics = {}
    for comp_type in compression_types:
        ratios = compression_results[comp_type]["compression_ratios"]
        metrics[comp_type] = {
            "avg_compression_ratio": f"{sum(ratios) / len(ratios):.2f}%",
            "total_context_length": len("\n\n".join(compression_results[comp_type]["compressed_chunks"])),
            "original_context_length": len("\n\n".join(standard_result["chunks"]))
        }
    return {
        "query": query,
        "responses": responses,
        "evaluation": evaluation,
        "metrics": metrics,
        "standard_result": standard_result,
        "compression_results": compression_results
    }


def visualize_compression_results(evaluation_results):
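    """Print example original vs. compressed chunks for each technique, plus a summary table."""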
    query = evaluation_results["query"]
    standard_chunks = evaluation_results["standard_result"]["chunks"]
    print(f"Query: {query}")
    print("\n" + "=" * 80 + "\n")
    original_chunk = standard_chunks[0]
    for comp_type in evaluation_results["compression_results"].keys():
        compressed_chunks = evaluation_results["compression_results"][comp_type]["compressed_chunks"]
        compression_ratios = evaluation_results["compression_results"][comp_type]["compression_ratios"]
        compressed_chunk = compressed_chunks[0]
        compression_ratio = compression_ratios[0]
        print(f"\n=== {comp_type.upper()} COMPRESSION EXAMPLE ===\n")
        print("ORIGINAL CHUNK:")
        print("-" * 40)
        if len(original_chunk) > 800:
            print(original_chunk[:800] + "... [truncated]")
        else:
            print(original_chunk)
        print("-" * 40)
        print(f"Length: {len(original_chunk)} characters\n")
        print("COMPRESSED CHUNK:")
        print("-" * 40)
        print(compressed_chunk)
        print("-" * 40)
        print(f"Length: {len(compressed_chunk)} characters")
        print(f"Compression ratio: {compression_ratio:.2f}%\n")
        avg_ratio = sum(compression_ratios) / len(compression_ratios)
        print(f"Average compression across all chunks: {avg_ratio:.2f}%")
        print(f"Total context length reduction: {evaluation_results['metrics'][comp_type]['avg_compression_ratio']}")
        print("=" * 80)
    print("\n=== COMPRESSION SUMMARY ===\n")
    print(f"{'Technique':<15} {'Avg Ratio':<15} {'Context Length':<15} {'Original Length':<15}")
    print("-" * 60)
    for comp_type, metrics in evaluation_results["metrics"].items():
        print(f"{comp_type:<15} {metrics['avg_compression_ratio']:<15} {metrics['total_context_length']:<15} {metrics['original_context_length']:<15}")



if __name__ == "__main__":
    pdf_path = "data/AI_Information.pdf"
    query = "What are the ethical concerns surrounding the use of AI in decision-making?"
    reference_answer = """
The use of AI in decision-making raises several ethical concerns.
- Bias in AI models can lead to unfair or discriminatory outcomes, especially in critical areas like hiring, lending, and law enforcement.
- Lack of transparency and explainability in AI-driven decisions makes it difficult for individuals to challenge unfair outcomes.
- Privacy risks arise as AI systems process vast amounts of personal data, often without explicit consent.
- The potential for job displacement due to automation raises social and economic concerns.
- AI decision-making may also concentrate power in the hands of a few large tech companies, leading to accountability challenges.
- Ensuring fairness, accountability, and transparency in AI systems is essential for ethical deployment.
"""
    results = evaluate_compression(
        pdf_path=pdf_path,
        query=query,
        reference_answer=reference_answer,
        compression_types=["selective", "summary", "extraction"]
    )
    visualize_compression_results(results)